
feat(gdn): separate input and output pool indices #2905

Open

feldsherov wants to merge 1 commit into flashinfer-ai:main from feldsherov:gdn-decode-separate-input-and-output-indices

Conversation

@feldsherov
Contributor

@feldsherov feldsherov commented Mar 28, 2026

📌 Description

Introduce a separate output indices parameter for gated_delta_rule_decode_pretranspose.

This addresses the decode part of the feature request in #2873.

🔍 Related Issues

#2873

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

I've only run tests/gdn/test_decode_delta_rule.py on an H200. I need help with running the whole test suite.

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Optional control to write updated recurrent states to caller-specified output slots (separate from read/input slots); permitted only in pool (initial-state) mode and includes input validation while preserving existing behavior when not used.
  • Tests
    • Added tests covering separate read/write state indexing and the case where output indices equal input indices, validating outputs and pool mutations.

@coderabbitai
Contributor

coderabbitai bot commented Mar 28, 2026

📝 Walkthrough

Walkthrough

Adds an optional output_state_indices argument across the gated-delta-rule decode path and kernels so updated recurrent states can be written to caller-specified pool slots different from the read indices. The argument is validated for pool mode and plumbed through BF16 MTP and pretranspose kernels; tests cover separate read/write indexing.
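The read/write split described above amounts to gather→compute→scatter over the state pool. A minimal CPU sketch of that contract (the `toy_update` function, pool shape, and slot values are illustrative stand-ins, not the real gated-delta kernel math):

```python
import torch

# Toy stand-in for the decode step; the +1 "update" is illustrative only,
# not the real gated-delta computation.
def toy_update(state: torch.Tensor) -> torch.Tensor:
    return state + 1.0

pool = torch.zeros(4, 2)           # hypothetical state pool with 4 slots
pool[1] = 10.0
read_idx = torch.tensor([1, 2])    # h0_indices: read slot per batch item
write_idx = torch.tensor([3, 0])   # h0_out_indices: write slot per batch item

# gather -> compute -> scatter, matching the kernel's read/write split
pool[write_idx] = toy_update(pool[read_idx])

print(pool[3, 0].item())  # 11.0: updated copy of slot 1 landed in slot 3
print(pool[1, 0].item())  # 10.0: the read slot itself is untouched
```

When `output_state_indices` is omitted, `write_idx` simply equals `read_idx` and the pool is updated in place, preserving the existing behavior.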

Changes

  • Public API — flashinfer/gdn_decode.py: Added output_state_indices: Optional[torch.Tensor] = None to gated_delta_rule_decode_pretranspose(); validates that it is only used in pool mode (initial_state/initial_state_indices) and that it has shape [B] and an integer dtype.
  • BF16 MTP kernel + wrapper — flashinfer/gdn_kernels/gdn_decode_bf16_state.py: Extended the MTP kernel and launch wrapper to accept h0_out_indices (write indices). Kernel writeback now uses h0_out_indices (write mapping) while h0_indices remains the read mapping. The public gated_delta_rule_mtp(...) accepts an optional output_state_indices and forwards it.
  • Pretranspose kernels + launcher — flashinfer/gdn_kernels/gdn_decode_pretranspose.py: Added an h0_out_indices parameter to the small/big pretranspose kernels and their launchers. Kernels compute separate read (h0_indices) and write (h0_out_indices) pool indices; run_pretranspose_decode() accepts output_state_indices (defaulting to the read indices when absent).
  • Tests — tests/gdn/test_decode_delta_rule.py: Added SM90+ tests test_output_state_indices and test_output_state_indices_same_as_input to validate distinct read/write behavior and equivalence when output equals input.

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant API as gated_delta_rule_decode_pretranspose
    participant Runner as run_pretranspose_decode / gated_delta_rule_mtp
    participant Kernel as CUDA Kernel (pretranspose / bf16 MTP)
    participant Pool as State Pool

    User->>API: call with initial_state (+pool), initial_state_indices (read), output_state_indices (write)
    API->>API: validate output_state_indices (pool-mode, shape [B], int32/64)
    API->>Runner: forward tensors and indices
    Runner->>Kernel: launch with h0_indices (read) and h0_out_indices (write)
    Kernel->>Pool: read state from Pool[h0_indices[b]]
    Kernel->>Kernel: compute gated-delta updates
    Kernel->>Pool: write updated state to Pool[h0_out_indices[b]]
    Pool-->>User: outputs and mutated pool
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Suggested reviewers

  • yzh119
  • bkryu
  • yongwww
  • kahyunnam

Poem

🐰 I hop and I map every index with care,
Read from one burrow, write to another fair.
Pools now split paths, no accidental blends,
Tiny rabbit tweaks so each state finds new friends. 🥕

🚥 Pre-merge checks | ✅ 3 passed
  • Title check ✅ Passed — The title clearly and concisely describes the main feature: separating input and output pool indices for the gated delta rule decode function.
  • Description check ✅ Passed — The description follows the template structure with a summary, related issue, and completed pre-commit and test checklist items; however, not all tests are confirmed passing.
  • Docstring coverage ✅ Passed — Docstring coverage is 90.91%, which meets the required threshold of 80.00%.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the output_state_indices parameter to the Gated Delta Rule decode kernels, enabling the updated state to be written to a different pool slot than the one read from. The changes include updates to the high-level Python API, the underlying CUDA kernels in gdn_decode_bf16_state.py and gdn_decode_pretranspose.py, and the addition of verification tests. Review feedback recommends grouping index reads within the kernels to improve consistency and instruction-level parallelism, as well as simplifying redundant logic in the pretranspose runner.

Comment on lines +754 to +759

```python
pool_batch_idx = gH_slot_indices[batch_idx]
if pool_batch_idx < 0:
    pool_batch_idx = cutlass.Int32(0)
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if write_pool_batch_idx < 0:
    write_pool_batch_idx = cutlass.Int32(0)
```

medium

This logic for handling negative indices is duplicated in gated_delta_rule_decode_kernel_seqlen234_unified and gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk. To improve consistency and potentially instruction-level parallelism, consider grouping the reads together before the checks, as done in gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk.

Additionally, to reduce code duplication across kernels, you could introduce a cute.jit helper function at the module level to handle this pattern.

Suggested change

```python
# Before
pool_batch_idx = gH_slot_indices[batch_idx]
if pool_batch_idx < 0:
    pool_batch_idx = cutlass.Int32(0)
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if write_pool_batch_idx < 0:
    write_pool_batch_idx = cutlass.Int32(0)

# After: group the reads before the checks
pool_batch_idx = gH_slot_indices[batch_idx]
write_pool_batch_idx = gH_out_slot_indices[batch_idx]
if pool_batch_idx < 0:
    pool_batch_idx = cutlass.Int32(0)
if write_pool_batch_idx < 0:
    write_pool_batch_idx = cutlass.Int32(0)
```

Comment on lines +976 to +979

```python
if use_pool_indexing and output_state_indices is not None:
    h0_out_indices = output_state_indices.to(torch.int32)
else:
    h0_out_indices = h0_indices
```

medium

The use_pool_indexing check here is redundant. The public API gated_delta_rule_decode_pretranspose already asserts that output_state_indices can only be provided when use_pool_indexing is true.

You can simplify this logic for better readability.

Suggested change

```python
# Before
if use_pool_indexing and output_state_indices is not None:
    h0_out_indices = output_state_indices.to(torch.int32)
else:
    h0_out_indices = h0_indices

# After: drop the redundant use_pool_indexing check
if output_state_indices is not None:
    h0_out_indices = output_state_indices.to(torch.int32)
else:
    h0_out_indices = h0_indices
```

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 748e9c86-0f49-4322-be15-919a3f8c3f91

📥 Commits

Reviewing files that changed from the base of the PR and between 904fa8c and 93687a1.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py

Comment thread flashinfer/gdn_decode.py
Comment on lines +195 to +206

```python
if output_state_indices is not None:
    assert use_pool, (
        "output_state_indices can only be used with initial_state (pool mode)"
    )
    assert output_state_indices.shape == (B,), (
        f"Expected output_state_indices shape [{B}], "
        f"got {output_state_indices.shape}"
    )
    assert output_state_indices.dtype in (torch.int32, torch.int64), (
        f"output_state_indices must be int32 or int64, "
        f"got {output_state_indices.dtype}"
    )
```

⚠️ Potential issue | 🔴 Critical

Reject in-place remaps that alias another batch item's source slot.

output_state_indices still writes back into the same initial_state buffer during the same kernel launch. If two batch items target the same write slot, or one item writes a slot another item is still reading via initial_state_indices, the final state becomes CTA-order dependent and no longer matches gather→compute→scatter semantics. Please either validate a safe mapping here or route overlapping remaps through a staged fallback.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The output_state_indices
path currently allows in-place remaps that can alias other batch items' source
slots, making final state CTA-order dependent; in the block that checks
output_state_indices (and uses use_pool and initial_state /
initial_state_indices), validate that output_state_indices contains no duplicate
targets and that none of its target indices overlap any indices in
initial_state_indices (or raise a clear error); alternatively implement a staged
fallback: allocate a temporary buffer, gather sources into temp using
initial_state_indices, perform compute, then scatter results from temp to
initial_state using output_state_indices to avoid read/write races. Ensure
checks/reference to output_state_indices, initial_state, initial_state_indices
and use_pool are used so the change locates the remap logic.

⚠️ Potential issue | 🔴 Critical

Validate output_state_indices against the pool before dispatch.

The new arg is only shape/dtype-checked. A CPU tensor here will fail late, and a negative or >= pool_size write index can either become an out-of-bounds store on the float32 pretranspose path or silently alias slot 0 on the bf16 path. Please reject non-local or out-of-range write indices here unless you want explicit write-side padding semantics.

💡 Suggested guard

```diff
     if output_state_indices is not None:
         assert use_pool, (
             "output_state_indices can only be used with initial_state (pool mode)"
         )
         assert output_state_indices.shape == (B,), (
             f"Expected output_state_indices shape [{B}], "
             f"got {output_state_indices.shape}"
         )
         assert output_state_indices.dtype in (torch.int32, torch.int64), (
             f"output_state_indices must be int32 or int64, "
             f"got {output_state_indices.dtype}"
         )
+        assert output_state_indices.device == initial_state.device, (
+            "output_state_indices must be on the same device as initial_state"
+        )
+        pool_size = int(initial_state.shape[0])
+        in_range = (output_state_indices >= 0) & (output_state_indices < pool_size)
+        assert in_range.all().item(), (
+            f"output_state_indices must be in [0, {pool_size})"
+        )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 195 - 206, The code currently only
checks shape/dtype of output_state_indices; add validation that
output_state_indices is on the same device as the pool (reject CPU/non-local
tensors) and that all values are within [0, pool_size-1] to prevent
out-of-bounds or aliasing when writing into the pool (when
use_pool/initial_state is active). In the gdn_decode logic where
output_state_indices is handled (the block that asserts use_pool and checks
shape/dtype), add checks for device equality to the pool tensor and use
torch.any((idx < 0) | (idx >= pool_size)) or equivalent to raise a clear
ValueError/Assertion if any index is out of range; keep references to
output_state_indices, use_pool, pool_size, and initial_state so the guard runs
early and fails fast.

Comment on lines +1756 to +1767

```python
# Read slots must be unchanged (we wrote to different slots)
torch.testing.assert_close(
    pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
)

# Other slots must be unchanged
used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
used_mask[read_indices] = True
used_mask[write_indices] = True
torch.testing.assert_close(
    pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
)
```

⚠️ Potential issue | 🟡 Minor

Use exact equality for slots that must stay untouched.

These assertions are checking for no mutation, not numerical closeness. Keeping atol/rtol=1e-3 can hide a small stray write, so the read slots and the untouched remainder should be compared with zero tolerance.

💡 Tighten the unchanged-slot assertions

```diff
     # Read slots must be unchanged (we wrote to different slots)
     torch.testing.assert_close(
-        pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+        pool_under_test[read_indices],
+        pool_orig[read_indices],
+        atol=0.0,
+        rtol=0.0,
     )

     # Other slots must be unchanged
     used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
     used_mask[read_indices] = True
     used_mask[write_indices] = True
     torch.testing.assert_close(
-        pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+        pool_under_test[~used_mask],
+        pool_orig[~used_mask],
+        atol=0.0,
+        rtol=0.0,
     )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1756 - 1767, The tests
currently use torch.testing.assert_close to check that
pool_under_test[read_indices] and pool_under_test[~used_mask] match pool_orig
with nonzero atol/rtol; change these to exact-equality checks (e.g., use
torch.equal or torch.testing.assert_close(..., atol=0, rtol=0)) for the two
assertions involving pool_under_test, pool_orig, read_indices and the computed
used_mask/write_indices so any stray mutation is caught.

@saltyminty
Collaborator

saltyminty commented Apr 1, 2026

Approved conditional on CI.

Edit: though it seems I don't have write access so will need another reviewer to take a look

@saltyminty
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !488 has been created, and the CI pipeline #47476702 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47476702: 7/20 passed

@feldsherov
Contributor Author

feldsherov commented Apr 2, 2026

@saltyminty thank you for the review!

I see flashinfer-bot is reporting failed CI pipeline. I am happy to address any issues, but I don't have access to the CI results.
Can you help me in any way?

@feldsherov
Contributor Author

@kahyunnam thank you for the review!

@kahyunnam @saltyminty what should I do to land this PR?

@feldsherov feldsherov force-pushed the gdn-decode-separate-input-and-output-indices branch from 93687a1 to dc5af70 Compare April 6, 2026 14:26
@feldsherov
Contributor Author

In the meantime, I rebased onto the latest main and adjusted gdn_decode_bf16state_mtp_kernel to support the change.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 63235c41-a9c7-4f65-a721-9399edaa4c08

📥 Commits

Reviewing files that changed from the base of the PR and between 93687a1 and dc5af70.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py

Comment on lines +2549 to +2553

```python
# Default output indices to read indices
if output_state_indices is None:
    output_state_indices = initial_state_indices
elif output_state_indices.dtype != torch.int32:
    output_state_indices = output_state_indices.to(torch.int32)
```

⚠️ Potential issue | 🔴 Critical

Preserve padding/null-buffer semantics when defaulting output_state_indices.

This regresses the existing BF16 negative-index path: padded reads still come in as initial_state_indices == -1, but None now copies that -1 straight onto the write side. The kernel uses h0_out_indices for final writeback, so padded rows now write out of bounds before h0_source instead of falling back to slot 0.

🐛 Minimal fix

```diff
-    if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
-        output_state_indices = output_state_indices.to(torch.int32)
+    if output_state_indices is None:
+        # Preserve the existing slot-0 null-buffer behavior for padded rows.
+        output_state_indices = initial_state_indices.clamp_min(0)
+    if output_state_indices.dtype != torch.int32:
+        output_state_indices = output_state_indices.to(torch.int32)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
When defaulting output_state_indices (when output_state_indices is None),
preserve padding/null-buffer semantics by cloning initial_state_indices but
mapping padded markers (-1) back to the fallback write slot (e.g., 0) before
use; specifically, in the block handling output_state_indices, set
output_state_indices = initial_state_indices.clone(), then replace any entries
equal to -1 with 0, and finally ensure dtype is torch.int32. This keeps the
kernel's h0_out_indices behavior correct (padded reads won't write to -1
locations) while keeping the int32 conversion logic.
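The padding convention at issue can be checked on the CPU. This sketch assumes the thread's stated convention that -1 in initial_state_indices marks a padded row whose writes historically fell back to slot 0:

```python
import torch

# Assumed convention from this thread: -1 marks a padded row; the BF16 path
# historically redirected such rows to null-buffer slot 0.
initial_state_indices = torch.tensor([5, -1, 2], dtype=torch.int32)

# Carrying -1 straight onto the write side would store before the pool;
# clamping restores the slot-0 fallback for padded rows.
output_state_indices = initial_state_indices.clamp_min(0)
print(output_state_indices.tolist())  # [5, 0, 2]
```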

@saltyminty saltyminty force-pushed the gdn-decode-separate-input-and-output-indices branch from dc5af70 to c8d21f0 Compare April 8, 2026 17:24
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (4)
tests/gdn/test_decode_delta_rule.py (1)

2182-2193: ⚠️ Potential issue | 🟡 Minor

Use exact equality for untouched slots.

These assertions are checking for no mutation, not closeness. Keeping atol/rtol=1e-3 can hide a small stray write in the read slots or the untouched remainder.

Tighten the unchanged-slot checks

```diff
     # Read slots must be unchanged (we wrote to different slots)
     torch.testing.assert_close(
-        pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+        pool_under_test[read_indices],
+        pool_orig[read_indices],
+        atol=0.0,
+        rtol=0.0,
     )

     # Other slots must be unchanged
     used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
     used_mask[read_indices] = True
     used_mask[write_indices] = True
     torch.testing.assert_close(
-        pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+        pool_under_test[~used_mask],
+        pool_orig[~used_mask],
+        atol=0.0,
+        rtol=0.0,
     )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 2182 - 2193, Change the
"unchanged slots" checks to require exact equality instead of approximate
closeness: for the read-only check comparing pool_under_test[read_indices] and
pool_orig[read_indices] and for the unused-slot check comparing
pool_under_test[~used_mask] and pool_orig[~used_mask], replace the current
torch.testing.assert_close usage (which uses atol/rtol) with an exact-equality
assertion (e.g., use torch.testing.assert_close with rtol=0 and atol=0 or
torch.testing.assert_equal / torch.equal) so any stray writes to pool_under_test
(referenced by pool_under_test, pool_orig, read_indices, write_indices,
used_mask) will fail the test.
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)

2549-2553: ⚠️ Potential issue | 🔴 Critical

Preserve null-buffer semantics when defaulting write indices.

When output_state_indices is omitted, this copies -1 padding markers from initial_state_indices onto the write side. The kernel clamps cache_idx, but final writeback still uses write_cache_idx, so padded rows can store out of bounds before the pool instead of falling back to slot 0.

Minimal fix

```diff
-    if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
-        output_state_indices = output_state_indices.to(torch.int32)
+    if output_state_indices is None:
+        # Preserve the slot-0 null-buffer behavior for padded rows.
+        output_state_indices = initial_state_indices.clamp_min(0)
+    if output_state_indices.dtype != torch.int32:
+        output_state_indices = output_state_indices.to(torch.int32)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
The bug is that defaulting output_state_indices to initial_state_indices copies
-1 padding markers to the write side, allowing padded rows to be written;
instead, when output_state_indices is None set it to a new int32 tensor of the
same shape filled with -1 (preserving null-buffer/write-disabled semantics)
rather than aliasing initial_state_indices; ensure subsequent dtype logic still
converts tensors to torch.int32 (use torch.full_like(initial_state_indices, -1,
dtype=torch.int32) or torch.full(initial_state_indices.shape, -1,
dtype=torch.int32) for the assignment to output_state_indices).
flashinfer/gdn_decode.py (2)

199-210: ⚠️ Potential issue | 🔴 Critical

Reject unsafe pool remaps.

output_state_indices still allows duplicate destinations and cross-batch read/write overlap. In the in-place pool path that makes the final state CTA-order dependent instead of equivalent to gather→compute→scatter. Please reject those mappings here, or route remaps through a staged buffer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 199 - 210, The current
output_state_indices validation allows duplicate targets and cross-batch
overlaps which makes the in-place pool path CTA-order dependent; update the
check in the block that currently validates output_state_indices (the code
around the output_state_indices assertions in gdn_decode.py) to reject any
non-permutation mapping: assert that all values are in range [0, B) and that
torch.unique(output_state_indices).numel() == B (i.e., no duplicates and full
bijection), and raise a clear assertion/error message like "output_state_indices
must be a permutation (no duplicates or cross-batch overlaps) when using pool
mode"; alternatively if you prefer to support non-permutations implement a
staged-buffer path (copy to a temporary buffer then scatter) and route
non-permutation remaps through that path instead of the in-place pool path.

199-210: ⚠️ Potential issue | 🔴 Critical

Fail fast on invalid destination slots.

This still only checks shape/dtype. A CPU tensor, a negative destination, or an index >= pool_size can reach the kernels and turn into an invalid state write.

Suggested guard

```diff
     if output_state_indices is not None:
         assert use_pool, (
             "output_state_indices can only be used with initial_state (pool mode)"
         )
         assert output_state_indices.shape == (B,), (
             f"Expected output_state_indices shape [{B}], "
             f"got {output_state_indices.shape}"
         )
         assert output_state_indices.dtype in (torch.int32, torch.int64), (
             f"output_state_indices must be int32 or int64, "
             f"got {output_state_indices.dtype}"
         )
+        assert output_state_indices.device == initial_state.device, (
+            "output_state_indices must be on the same device as initial_state"
+        )
+        pool_size = int(initial_state.shape[0])
+        in_range = (output_state_indices >= 0) & (
+            output_state_indices < pool_size
+        )
+        assert torch.all(in_range).item(), (
+            f"output_state_indices must be in [0, {pool_size})"
+        )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 199 - 210, Add strict runtime guards
for output_state_indices: when output_state_indices is not None (and use_pool is
true), assert it is on the same device as the pool/initial_state (or at least a
CUDA device, not CPU), and assert all indices are within [0, pool_size-1] (no
negatives and none >= pool_size) before passing to kernels; use tensor
operations like output_state_indices.min() and output_state_indices.max() (or
torch.any checks) to detect out-of-range values and raise clear AssertionError
messages referencing output_state_indices, use_pool, and pool_size so invalid
destination slots cannot reach the GPU kernels.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 199-210: The current output_state_indices validation allows
duplicate targets and cross-batch overlaps which makes the in-place pool path
CTA-order dependent; update the check in the block that currently validates
output_state_indices (the code around the output_state_indices assertions in
gdn_decode.py) to reject any non-permutation mapping: assert that all values are
in range [0, B) and that torch.unique(output_state_indices).numel() == B (i.e.,
no duplicates and full bijection), and raise a clear assertion/error message
like "output_state_indices must be a permutation (no duplicates or cross-batch
overlaps) when using pool mode"; alternatively if you prefer to support
non-permutations implement a staged-buffer path (copy to a temporary buffer then
scatter) and route non-permutation remaps through that path instead of the
in-place pool path.
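The permutation check suggested above can be sketched as follows. This is an illustrative helper (the name `assert_permutation` is invented here), assuming `output_state_indices` holds one destination slot per batch row:

```python
import torch

def assert_permutation(output_state_indices, B):
    """Hypothetical check: reject non-permutation remaps in pool mode."""
    in_range = (output_state_indices >= 0) & (output_state_indices < B)
    assert bool(in_range.all()), f"indices must lie in [0, {B})"
    # B in-range values with no duplicates form a bijection on [0, B),
    # so no two rows can race on the same destination slot.
    assert int(torch.unique(output_state_indices).numel()) == B, (
        "output_state_indices must be a permutation "
        "(no duplicates or cross-batch overlaps) when using pool mode"
    )
```

Rejecting duplicates up front removes the CTA-order dependence: with a bijection, every in-place write targets a distinct slot regardless of launch order.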

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2549-2553: The bug is that defaulting output_state_indices to
initial_state_indices copies -1 padding markers to the write side, allowing
padded rows to be written; instead, when output_state_indices is None set it to
a new int32 tensor of the same shape filled with -1 (preserving
null-buffer/write-disabled semantics) rather than aliasing
initial_state_indices; ensure subsequent dtype logic still converts tensors to
torch.int32 (use torch.full_like(initial_state_indices, -1, dtype=torch.int32)
or torch.full(initial_state_indices.shape, -1, dtype=torch.int32) for the
assignment to output_state_indices).

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 2182-2193: Change the "unchanged slots" checks to require exact
equality instead of approximate closeness: for the read-only check comparing
pool_under_test[read_indices] and pool_orig[read_indices] and for the
unused-slot check comparing pool_under_test[~used_mask] and
pool_orig[~used_mask], replace the current torch.testing.assert_close usage
(which uses atol/rtol) with an exact-equality assertion (e.g., use
torch.testing.assert_close with rtol=0 and atol=0 or torch.testing.assert_equal
/ torch.equal) so any stray writes to pool_under_test (referenced by
pool_under_test, pool_orig, read_indices, write_indices, used_mask) will fail
the test.
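Why the tolerance matters here is easy to demonstrate. The following self-contained illustration (not the actual test code) shows a stray sub-tolerance write slipping past the tolerant check while the exact check catches it:

```python
import torch

pool_orig = torch.zeros(4, 8)
pool_under_test = pool_orig.clone()
pool_under_test[1, 0] = 5e-4  # stray write, smaller than atol=1e-3

# The tolerant check passes and masks the corruption:
# allowed diff = atol + rtol * |expected| = 1e-3, actual diff = 5e-4.
torch.testing.assert_close(pool_under_test, pool_orig, atol=1e-3, rtol=1e-3)

# The zero-tolerance check catches the same stray write.
caught = False
try:
    torch.testing.assert_close(pool_under_test, pool_orig, atol=0.0, rtol=0.0)
except AssertionError:
    caught = True
assert caught
```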

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cd63ccca-d797-4d84-ac61-caebd3460c40

📥 Commits

Reviewing files that changed from the base of the PR and between dc5af70 and c8d21f0.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py

@saltyminty force-pushed the gdn-decode-separate-input-and-output-indices branch from c8d21f0 to 58cc26d on April 8, 2026 at 23:57
@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
tests/gdn/test_decode_delta_rule.py (1)

2182-2193: ⚠️ Potential issue | 🟡 Minor

Use exact equality for slots that must remain untouched.

These assertions verify that unmodified pool slots remain unchanged. Using atol=1e-3, rtol=1e-3 can mask small stray writes. For immutability checks, use zero tolerance.

💡 Tighten the unchanged-slot assertions
     # Read slots must be unchanged (we wrote to different slots)
     torch.testing.assert_close(
-        pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol
+        pool_under_test[read_indices],
+        pool_orig[read_indices],
+        atol=0.0,
+        rtol=0.0,
     )

     # Other slots must be unchanged
     used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device)
     used_mask[read_indices] = True
     used_mask[write_indices] = True
     torch.testing.assert_close(
-        pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol
+        pool_under_test[~used_mask],
+        pool_orig[~used_mask],
+        atol=0.0,
+        rtol=0.0,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 2182 - 2193, The assertions
that verify unmodified slots use non-zero tolerances which can hide stray
writes; change the checks on pool_under_test vs pool_orig for read_indices and
~used_mask to use exact equality (e.g., torch.equal or
torch.testing.assert_close with atol=0, rtol=0) so that
pool_under_test[read_indices] and pool_under_test[~used_mask] must match
pool_orig exactly; update the two assert calls referencing pool_under_test,
pool_orig, read_indices, write_indices, and used_mask accordingly.
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)

2549-2553: ⚠️ Potential issue | 🔴 Critical

Preserve padding/null-buffer semantics when defaulting output_state_indices.

When initial_state_indices contains -1 (padding markers), directly assigning it to output_state_indices will cause the kernel to compute flat_write_idx = -1 * HV + i_hv (a negative offset) and write to invalid memory at lines 1981-2012. The read path has protection (lines 1227-1229 clamp negative cache_idx to 0), but the write path has no equivalent guard.

🐛 Proposed fix: clamp negative indices to slot 0 for writes
     # Default output indices to read indices
     if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
+        # Preserve the existing slot-0 null-buffer behavior for padded rows.
+        output_state_indices = initial_state_indices.clamp(min=0)
+    if output_state_indices.dtype != torch.int32:
         output_state_indices = output_state_indices.to(torch.int32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2549 - 2553,
When defaulting output_state_indices to initial_state_indices, avoid assigning
the same tensor with -1 padding because the write path computes flat_write_idx
and will write to negative offsets; instead, in the branch where
output_state_indices is None, create a clone of initial_state_indices, replace
negative values (e.g. -1) with 0 to preserve the null-slot semantics for writes,
and then ensure the tensor is converted to torch.int32 (matching the existing
dtype-check branch). Update the code that sets output_state_indices so it uses
output_state_indices = initial_state_indices.clone();
output_state_indices[output_state_indices < 0] = 0; output_state_indices =
output_state_indices.to(torch.int32) (or equivalent) so flat_write_idx cannot be
negative when used with HV and i_hv.
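That proposed fix can be sketched as a standalone helper (the name `default_write_indices_clamped` is invented here for illustration):

```python
import torch

def default_write_indices_clamped(initial_state_indices, output_state_indices=None):
    """Sketch of the proposed fix: when defaulting the write indices to the
    read indices, clamp -1 padding markers to slot 0 so the kernel's flat
    write offset (-1 * HV + i_hv) can never go negative."""
    if output_state_indices is None:
        output_state_indices = initial_state_indices.clone()
        output_state_indices[output_state_indices < 0] = 0
    return output_state_indices.to(torch.int32)
```

Clamping to slot 0 mirrors the protection the read path already has, where negative cache indices are clamped before use, so padded rows fall back to the existing slot-0 null-buffer behavior.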

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3392082e-e475-483c-bc9d-773a2995d6ff

📥 Commits

Reviewing files that changed from the base of the PR and between c8d21f0 and 58cc26d.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_delta_rule.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/gdn_decode.py

dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True)
h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True)
h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
⚠️ Potential issue | 🟡 Minor

Fix formatting to pass pre-commit checks.

The pipeline failure indicates this line needs reformatting per ruff format.

🔧 Apply ruff formatting
-    h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True)
+    h0_out_idx_ = from_dlpack(
+        output_state_indices, assumed_align=32, enable_tvm_ffi=True
+    )
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 2590-2593: pre-commit failed: ruff format (hook id: ruff-format) reformatted files. Diff shows formatting change in gated_delta_rule_mtp() for h0_out_idx_ = from_dlpack(output_state_indices, ...).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` at line 2593, The assignment
to h0_out_idx_ calling from_dlpack is misformatted; reformat that line to
satisfy ruff (apply ruff format or adjust spacing/punctuation) so it matches the
project's formatting rules (e.g., proper spacing around the = and within the
function call) in the h0_out_idx_ = from_dlpack(...) statement; keep the same
variable name h0_out_idx_ and function call from_dlpack with arguments
output_state_indices, assumed_align=32, enable_tvm_ffi=True.

@feldsherov
Contributor Author

@saltyminty should I do anything here?

@kahyunnam
Collaborator

@feldsherov this is failing pre-commit tests: https://github.com/flashinfer-ai/flashinfer/actions/runs/24164935099/job/70524339364?pr=2905

Can you please rerun pre-commit and push?

